syntax = "proto3";
option cc_enable_arenas = true;

package tensorflow.boosted_trees.trees;

// TreeNode describes a node in a tree.
// A node carries at most one payload from the `node` oneof (either a leaf or
// one of the split rules defined in this file) plus a separate metadata field.
message TreeNode {
  // The node payload. Members are mutually exclusive: setting one clears any
  // other that was previously set.
  oneof node {
    // Leaf payload (dense or sparse vector; see Leaf).
    Leaf leaf = 1;
    // Threshold split on a dense float feature.
    DenseFloatBinarySplit dense_float_binary_split = 2;
    // Threshold split on a sparse float feature; missing features go left.
    SparseFloatBinarySplitDefaultLeft sparse_float_binary_split_default_left =
        3;
    // Threshold split on a sparse float feature; missing features go right.
    SparseFloatBinarySplitDefaultRight sparse_float_binary_split_default_right =
        4;
    // Equality split on a single categorical feature Id.
    CategoricalIdBinarySplit categorical_id_binary_split = 5;
    // Set-membership split on a set of categorical feature Ids.
    CategoricalIdSetMembershipBinarySplit
        categorical_id_set_membership_binary_split = 6;
    // Oblivious-tree variant of the dense float threshold split (no child ids).
    ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
    // Oblivious-tree variant of the categorical equality split (no child ids).
    ObliviousCategoricalIdBinarySplit oblivious_categorical_id_binary_split = 8;
  }
  // Metadata for this node. Declared outside the oneof, so it can be set
  // together with any payload.
  // NOTE(review): field number 777 is far from the oneof's 1-8 range,
  // presumably to leave room for new payload types -- confirm before adding
  // fields near it.
  TreeNodeMetadata node_metadata = 777;
}

// TreeNodeMetadata encodes metadata associated with each node in a tree.
// It is attached to a node through TreeNode.node_metadata.
message TreeNodeMetadata {
  // The gain associated with this node.
  // NOTE(review): the unit/definition of "gain" is not specified in this file;
  // see the code that populates it.
  float gain = 1;

  // The original leaf node before this node was split.
  // As a message-typed field it has explicit presence, so "never set" can be
  // told apart from an empty Leaf.
  Leaf original_leaf = 2;

  // The original layer of leaves before that layer was converted to a split.
  // NOTE(review): presumably only populated for the oblivious split node
  // types -- confirm against the writer.
  repeated Leaf original_oblivious_leaves = 3;
}

// Leaves can either hold dense or sparse information.
// Exactly one representation can be set at a time; both may also be unset.
message Leaf {
  // The leaf payload; `vector` and `sparse_vector` are mutually exclusive.
  oneof leaf {
    // See third_party/tensorflow/contrib/decision_trees/
    // proto/generic_tree_model.proto
    // for a description of how vector and sparse_vector might be used.

    // Dense representation: a plain list of float values.
    Vector vector = 1;
    // Sparse representation: parallel lists of indices and float values.
    SparseVector sparse_vector = 2;
  }
}

// Dense vector of float values, used as the dense payload of a Leaf.
message Vector {
  // The vector's values in order. An empty list is indistinguishable from an
  // unset field (repeated fields have no presence).
  repeated float value = 1;
}

// Sparse vector of float values, used as the sparse payload of a Leaf.
// The data is stored as two separate repeated fields.
// NOTE(review): presumably index[i] is the position of value[i] and both lists
// have equal length; the schema does not enforce this -- confirm with readers
// and writers.
message SparseVector {
  // Indices of the stored entries.
  repeated int32 index = 1;
  // Values of the stored entries.
  repeated float value = 2;
}

// Split rule for dense float features.
// Also embedded by SparseFloatBinarySplitDefaultLeft/Right to describe their
// threshold rule.
message DenseFloatBinarySplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_column = 1;
  // If feature column is multivalent, this holds the index of the dimension
  // for the split. Defaults to 0.
  // Note: declared out of numeric order (field number 5); the number is what
  // identifies the field on the wire and must not be changed to "tidy up".
  int32 dimension_id = 5;
  // Threshold the feature value is compared against (feature <= threshold).
  float threshold = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  // NOTE(review): presumably left_id is taken when the rule holds and right_id
  // otherwise -- confirm against the tree traversal code.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Split rule for sparse float features defaulting left for missing features.
// This is a distinct wrapper type so that the default direction is carried by
// the message type itself rather than by a flag.
message SparseFloatBinarySplitDefaultLeft {
  // The threshold rule and child ids, expressed with the dense split message.
  DenseFloatBinarySplit split = 1;
}

// Split rule for sparse float features defaulting right for missing features.
// This is a distinct wrapper type so that the default direction is carried by
// the message type itself rather than by a flag.
message SparseFloatBinarySplitDefaultRight {
  // The threshold rule and child ids, expressed with the dense split message.
  DenseFloatBinarySplit split = 1;
}

// Split rule for categorical features with a single feature Id.
message CategoricalIdBinarySplit {
  // Categorical feature column and Id describing
  // the rule feature == Id.
  int32 feature_column = 1;
  // The Id the feature is compared against for equality.
  int64 feature_id = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  // NOTE(review): presumably left_id is taken when the rule holds and right_id
  // otherwise -- confirm against the tree traversal code.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Split rule for categorical features with a set of feature Ids.
message CategoricalIdSetMembershipBinarySplit {
  // Categorical feature column and set of Ids describing
  // the rule feature ∈ feature_ids.
  int32 feature_column = 1;
  // Sorted list of Ids in the set.
  // The schema does not enforce the ordering; writers are responsible for
  // keeping the list sorted.
  repeated int64 feature_ids = 2;

  // Node children indexing into a contiguous
  // vector of nodes starting from the root.
  // NOTE(review): presumably left_id is taken when the rule holds and right_id
  // otherwise -- confirm against the tree traversal code.
  int32 left_id = 3;
  int32 right_id = 4;
}

// Split rule for dense float features in the oblivious case.
// Unlike DenseFloatBinarySplit, this message has neither child ids nor a
// dimension_id field.
message ObliviousDenseFloatBinarySplit {
  // Float feature column and split threshold describing
  // the rule feature <= threshold.
  int32 feature_column = 1;
  // Threshold the feature value is compared against (feature <= threshold).
  float threshold = 2;
  // We don't store children ids, because either the next node represents the
  // whole next layer of the tree or starting with the next node we only have
  // leaves.
}

// Split rule for categorical features with a single feature Id in the oblivious
// case.
// Unlike CategoricalIdBinarySplit, this message carries no child ids.
message ObliviousCategoricalIdBinarySplit {
  // Categorical feature column and Id describing the rule feature == Id.
  int32 feature_column = 1;
  // The Id the feature is compared against for equality.
  int64 feature_id = 2;
  // We don't store children ids, because either the next node represents the
  // whole next layer of the tree or starting with the next node we only have
  // leaves.
}

// DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
// Note that each node id is implicitly its index in the list of nodes.
message DecisionTreeConfig {
  // All nodes of the tree, root first. The left_id/right_id fields of the
  // split messages are indices into this list.
  repeated TreeNode nodes = 1;
}

// DecisionTreeMetadata holds per-tree bookkeeping. It is stored in
// DecisionTreeEnsembleConfig.tree_metadata.
message DecisionTreeMetadata {
  // How many times tree weight was updated (due to reweighting of the final
  // ensemble, dropout, shrinkage etc).
  int32 num_tree_weight_updates = 1;

  // Number of layers grown for this tree.
  int32 num_layers_grown = 2;

  // Whether the tree is finalized in that no more layers can be grown.
  // Defaults to false when unset.
  bool is_finalized = 3;
}

// GrowingMetadata holds ensemble-level counters that are used during
// training. It is stored in DecisionTreeEnsembleConfig.growing_metadata.
message GrowingMetadata {
  // Number of trees that we have attempted to build. After pruning, these
  // trees might have been removed.
  int64 num_trees_attempted = 1;
  // Number of layers that we have attempted to build. After pruning, these
  // layers might have been removed.
  int64 num_layers_attempted = 2;

  // Sorted list of ids of the column handlers that have been used in at least
  // one split so far.
  // The schema does not enforce the ordering; writers are responsible for
  // keeping the list sorted.
  repeated int64 used_handler_ids = 3;
}

// DecisionTreeEnsembleConfig describes an ensemble of decision trees.
// NOTE(review): trees, tree_weights and tree_metadata are separate repeated
// fields; presumably entry i of each list refers to the same tree and all
// three have equal length. The schema does not enforce this -- confirm with
// readers and writers.
message DecisionTreeEnsembleConfig {
  // The trees of the ensemble.
  repeated DecisionTreeConfig trees = 1;
  // Weights of the trees.
  repeated float tree_weights = 2;
  // Bookkeeping metadata of the trees.
  repeated DecisionTreeMetadata tree_metadata = 3;

  // Metadata that is used during the training.
  GrowingMetadata growing_metadata = 4;
}
